package mw

import (
	"context"

	"github.com/cloudwego/hertz/pkg/app"
	"gorm.io/gorm"
)

// DB_CTX_KEY is the RequestContext key under which InjectDB stores the
// request-scoped *gorm.DB and from which GetDB reads it back.
//
// NOTE(review): ALL_CAPS is not idiomatic Go (DBCtxKey would be), but the
// name is exported and may be referenced elsewhere, so it is left as-is.
const DB_CTX_KEY = "DB"

// InjectDB returns a hertz middleware that makes db available to downstream
// handlers. For every request it derives a GORM session bound to that
// request's context.Context (via db.WithContext) and stores it in the
// RequestContext under DB_CTX_KEY before continuing the handler chain.
// Handlers retrieve the session with GetDB.
func InjectDB(db *gorm.DB) app.HandlerFunc {
	handler := func(ctx context.Context, c *app.RequestContext) {
		// Bind the session to the request context so queries observe the
		// request's cancellation/deadline.
		scoped := db.WithContext(ctx)
		c.Set(DB_CTX_KEY, scoped)
		c.Next(ctx)
	}
	return handler
}

// GetDB returns the *gorm.DB previously stored in c under DB_CTX_KEY
// (normally by the InjectDB middleware).
//
// It panics with "GetDB: Invalid Context" when the key is absent or holds a
// value that is not a *gorm.DB; callers are expected to run behind InjectDB.
func GetDB(c *app.RequestContext) *gorm.DB {
	// The "exists" flag is not needed: a missing key yields a nil interface,
	// which fails the type assertion below just like a wrongly-typed value.
	value, _ := c.Get(DB_CTX_KEY)
	if db, ok := value.(*gorm.DB); ok {
		return db
	}
	panic("GetDB: Invalid Context")
}
